{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZyFAnMovQVdt"
      },
      "source": [
        "Copyright 2022 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pBLAnSoVOsGG"
      },
      "source": [
        "# Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TF8c2SYmZMjE"
      },
      "outputs": [],
      "source": [
        "# Copy google-research. A shallow clone is enough here: only the current\n",
        "# files are read and the full history of the repository is large.\n",
        "!git clone --depth 1 https://github.com/google-research/google-research.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cF30emzaIiiB"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import os\n",
        "import numpy as np\n",
        "import math\n",
        "import scipy.io.wavfile as wav\n",
        "import scipy.signal as signal\n",
        "from matplotlib import pylab as plt\n",
        "\n",
        "use_colabtools = False\n",
        "if use_colabtools:\n",
        "  import colabtools.sound\n",
        "  from colabtools import sound\n",
        "else:\n",
        "  import IPython"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0dHG3zo5l0KS"
      },
      "outputs": [],
      "source": [
        "# It was validated with TF 2.9.1\n",
        "# The assert below stops execution on any other TensorFlow version.\n",
        "print(tf.__version__)\n",
        "assert tf.__version__=='2.9.1'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m9xaiF05OunL"
      },
      "source": [
        "# Utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oEK2dNf6Iw7D"
      },
      "outputs": [],
      "source": [
        "def WavRead(filename, divide=False, target_sample_rate=16000):\n",
        "  \"\"\"Reads a wav file and returns float32 samples with their sample rate.\n",
        "\n",
        "  The samples are resampled to `target_sample_rate` when the file uses a\n",
        "  different rate, and are then scaled by the full-scale value of the sample\n",
        "  type stored in the file.\n",
        "\n",
        "  Arguments:\n",
        "    filename: path to the wav file\n",
        "    divide: not used; it is kept so that existing calls keep working\n",
        "    target_sample_rate: sample rate of the returned data\n",
        "\n",
        "  Returns:\n",
        "    Tuple (data, target_sample_rate), where data is a float32 numpy array.\n",
        "\n",
        "  Raises:\n",
        "    KeyError: if the sample type of the file is not listed below.\n",
        "  \"\"\"\n",
        "  # Maps the dtype returned by scipy to (zero offset, full-scale value).\n",
        "  sample_formats = {\n",
        "      'uint8': (128.0, 128.0),  # 8-bit PCM is unsigned and centered at 128.\n",
        "      'int16': (0.0, 32768.0),\n",
        "      'int32': (0.0, 2147483648.0),\n",
        "      'float32': (0.0, 1.0),\n",
        "      'float64': (0.0, 1.0),\n",
        "      }\n",
        "  samplerate, wave_data = wav.read(filename)\n",
        "  # The format is looked up before resampling, because resampling changes\n",
        "  # the dtype of the data.\n",
        "  offset, norm = sample_formats[wave_data.dtype.name]\n",
        "  if samplerate != target_sample_rate:\n",
        "    desired_length = int(\n",
        "        round(float(len(wave_data)) / samplerate * target_sample_rate))\n",
        "    wave_data = signal.resample(wave_data, desired_length)\n",
        "    print(\"resample input wav samplerate \" + str(samplerate))\n",
        "\n",
        "  data = np.array(wave_data, np.float32)\n",
        "  if offset:\n",
        "    # Move unsigned samples so that silence is at zero.\n",
        "    data = data - np.float32(offset)\n",
        "\n",
        "  # Normalize floats in range [-1..1).\n",
        "  data = data / norm\n",
        "\n",
        "  return data, target_sample_rate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jVVI6s77KM5k"
      },
      "outputs": [],
      "source": [
        "def RunNonStreaming(input_features, tflite_model_path):\n",
        "  \"\"\"Feeds the whole input to a TFLite model in a single invocation.\n",
        "\n",
        "  Inputs and outputs of the model are addressed by position: only the first\n",
        "  input and the first output are used.\n",
        "\n",
        "  Arguments:\n",
        "    input_features: features for the first input of the model\n",
        "    tflite_model_path: path to tflite model\n",
        "\n",
        "  Returns:\n",
        "    Content of the first output tensor after one invocation.\n",
        "  \"\"\"\n",
        "\n",
        "  interpreter = tf.lite.Interpreter(model_path=tflite_model_path)\n",
        "  model_input = interpreter.get_input_details()[0]\n",
        "  model_output = interpreter.get_output_details()[0]\n",
        "\n",
        "  # A -1 in the shape signature is treated as a dimension without fixed size:\n",
        "  # the input tensor is then resized to the shape of the given features.\n",
        "  needs_resize = -1 in model_input['shape_signature']\n",
        "  if needs_resize:\n",
        "    interpreter.resize_tensor_input(model_input['index'], input_features.shape)\n",
        "\n",
        "  interpreter.allocate_tensors()\n",
        "  interpreter.set_tensor(model_input['index'], input_features)\n",
        "  interpreter.invoke()\n",
        "\n",
        "  return interpreter.get_tensor(model_output['index'])\n",
        "\n",
        "\n",
        "def RunStreaming(input_features, step, tflite_model_path, inp_to_out, input_index=0, padding_index=-1):\n",
        "  \"\"\"Runs tflite_model in streaming mode.\n",
        "\n",
        "  It relies on assumption that tflite inputs/outputs are set in order and we can\n",
        "  access them by index.\n",
        "\n",
        "  The input is consumed in packets of `step` frames along axis 1. Trailing\n",
        "  frames which do not fill a whole packet are not processed.\n",
        "\n",
        "  Arguments:\n",
        "    input_features: input features\n",
        "    step: stride to process input data, it has to be positive\n",
        "    tflite_model_path: path to tflite model\n",
        "    inp_to_out: mapping from the index of an input to the index of the matching\n",
        "      output in TFLite module; it is used for the data and for the states\n",
        "    input_index: index of input data in TFLite module\n",
        "    padding_index: index of padding data in TFLite module.\n",
        "      It is optional: if -1 then ignored.\n",
        "\n",
        "  Returns:\n",
        "    Output produced by streaming: it is a concatenation of outputs produced\n",
        "     by streaming mode, so that we can compare it with non streaming output.\n",
        "     The output of a single packet is returned as is, and None is returned\n",
        "     if the input has fewer than `step` frames.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if step is not positive.\n",
        "  \"\"\"\n",
        "\n",
        "  # A non positive step would never advance through the input.\n",
        "  if step \u003c 1:\n",
        "    raise ValueError('step has to be positive, got ' + str(step))\n",
        "\n",
        "  interpreter = tf.lite.Interpreter(model_path=tflite_model_path)\n",
        "  interpreter.allocate_tensors()\n",
        "\n",
        "  input_details = interpreter.get_input_details()\n",
        "  output_details = interpreter.get_output_details()\n",
        "\n",
        "  # Index 0 is a valid padding input: only negative values disable padding.\n",
        "  use_padding = padding_index \u003e= 0\n",
        "\n",
        "  # Inputs which are neither data nor padding carry the streaming states.\n",
        "  state_indexes = [\n",
        "      s for s in range(len(input_details))\n",
        "      if s not in [input_index, padding_index]\n",
        "  ]\n",
        "\n",
        "  # create input states: all states start from zeros\n",
        "  input_states = {}\n",
        "  for s in state_indexes:\n",
        "    input_states[s] = np.zeros(\n",
        "        input_details[s]['shape'], dtype=input_details[s]['dtype'])\n",
        "\n",
        "  # Outputs are collected and concatenated once at the end, so that the\n",
        "  # accumulated output is not copied on every packet.\n",
        "  stream_outputs = []\n",
        "\n",
        "  start = 0\n",
        "  end = step\n",
        "  while end \u003c= input_features.shape[1]:\n",
        "    input_packet = input_features[:, start:end]\n",
        "\n",
        "    # update indexes of streamed updates\n",
        "    start = end\n",
        "    end += step\n",
        "\n",
        "    # set input audio data and, if it is used, the padding data\n",
        "    interpreter.set_tensor(input_details[input_index]['index'], input_packet)\n",
        "    if use_padding:\n",
        "      paddings_packet = tf.zeros(input_packet.shape[0:2])\n",
        "      interpreter.set_tensor(input_details[padding_index]['index'], paddings_packet)\n",
        "\n",
        "    # set input states\n",
        "    for s in state_indexes:\n",
        "      interpreter.set_tensor(input_details[s]['index'], input_states[s])\n",
        "\n",
        "    # run inference\n",
        "    interpreter.invoke()\n",
        "\n",
        "    # get output data (and ignore output paddings)\n",
        "    stream_outputs.append(\n",
        "        interpreter.get_tensor(output_details[inp_to_out[input_index]]['index']))\n",
        "\n",
        "    # get output states and set it back to input states\n",
        "    # which will be fed in the next inference cycle\n",
        "    for s in state_indexes:\n",
        "      # The function `get_tensor()` returns a copy of the tensor data.\n",
        "      # Use `tensor()` in order to get a pointer to the tensor.\n",
        "      input_states[s] = interpreter.get_tensor(output_details[inp_to_out[s]]['index'])\n",
        "\n",
        "  if not stream_outputs:\n",
        "    return None\n",
        "  if len(stream_outputs) == 1:\n",
        "    return stream_outputs[0]\n",
        "  return tf.concat(stream_outputs, axis=1)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CJnJDKvVJGjt"
      },
      "outputs": [],
      "source": [
        "def Wav2Spectrogram(wav_data, sample_rate=16000, frame_size_ms=50.0, frame_step_ms=12.5, preemph=0.97):\n",
        "  \"\"\"Converts a waveform to a log magnitude spectrogram.\n",
        "\n",
        "  The processing steps are: preemphasis, framing, Hann windowing, RFFT,\n",
        "  magnitude and log.\n",
        "\n",
        "  Arguments:\n",
        "    wav_data: waveform samples; a batch dimension is added in front of them\n",
        "      (assumed to be a 1D sequence of samples)\n",
        "    sample_rate: sample rate of wav_data, used to convert ms to samples\n",
        "    frame_size_ms: size of a frame in ms\n",
        "    frame_step_ms: step between consecutive frames in ms\n",
        "    preemph: preemphasis coefficient\n",
        "\n",
        "  Returns:\n",
        "    Log magnitude spectrogram. With default arguments and 1D input its shape\n",
        "    is [1, number of frames, 1025].\n",
        "  \"\"\"\n",
        "  frame_step = int(round(sample_rate * frame_step_ms / 1000.0))\n",
        "  frame_size = int(round(sample_rate * frame_size_ms / 1000.0))\n",
        "\n",
        "  input_features = tf.expand_dims(wav_data, 0)\n",
        "\n",
        "  # Preemphasis\n",
        "  pad = [[0, 0]] * input_features.shape.rank\n",
        "  pad[1] = [1, 0]  # Pad on the left side, because of preemphasis\n",
        "  input_features = tf.pad(input_features, pad, 'constant')\n",
        "  preemph_features = input_features[:, 1:] - preemph * input_features[:, :-1]\n",
        "\n",
        "  # Framing: frames which would run past the end of the signal are dropped\n",
        "  framed_features = tf.signal.frame(\n",
        "      preemph_features, frame_size, frame_step, pad_end=False)\n",
        "\n",
        "  # Windowing\n",
        "  window = tf.signal.hann_window(frame_size, periodic=True)\n",
        "  window_features = framed_features * window\n",
        "\n",
        "  # RFFT: the size is the smallest power of two which holds a frame, but not\n",
        "  # less than 2048. Integer arithmetic keeps the power of two exact.\n",
        "  fft_size = 1 \u003c\u003c (frame_size - 1).bit_length()\n",
        "  fft_size = max(2048, fft_size)\n",
        "  rfft = tf.signal.rfft(window_features, [fft_size])\n",
        "  magnitude_spectrum = tf.math.abs(rfft)\n",
        "\n",
        "  # Log: the small offset keeps the log finite for zero magnitudes\n",
        "  output_features = tf.math.log(magnitude_spectrum + 1e-2)\n",
        "  return output_features"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YW0uquqWOyYC"
      },
      "source": [
        "# Load input wav"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ME3Ybzf9VWjJ"
      },
      "source": [
        "### Set path to input wav file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M4ChO9cwVbpk"
      },
      "outputs": [],
      "source": [
        "# Input wav files come from the cloned repository, see\n",
        "# https://github.com/google-research/google-research/tree/master/specinvert/vctk/input\n",
        "wav_file_name = \"p232_118.wav\"\n",
        "wav_path = os.path.join(\"google-research/specinvert/vctk/input/\", wav_file_name)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f6p9xZPDI26h"
      },
      "outputs": [],
      "source": [
        "# Load the input waveform: WavRead resamples it to 16000 Hz by default and\n",
        "# returns float32 samples together with that sample rate.\n",
        "wav_data, sample_rate = WavRead(wav_path)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ii5vLA_1I5Qc"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "plt.plot(wav_data)\n",
        "if use_colabtools:\n",
        "  colabtools.sound.PlaySound(wav_data, sample_rate)\n",
        "else:\n",
        "  # Only the last top level expression of a cell is displayed automatically,\n",
        "  # so the player created inside of this branch is displayed explicitly.\n",
        "  IPython.display.display(IPython.display.Audio(wav_path))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "91E4-NqTO2u9"
      },
      "source": [
        "# Convert wav to spectrogram"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RE2pJmGmJKa2"
      },
      "outputs": [],
      "source": [
        "# Log magnitude spectrogram of the input waveform with a leading batch\n",
        "# dimension; it is the input of all models below.\n",
        "spectrogram_magnitude = Wav2Spectrogram(wav_data)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 9,
          "status": "ok",
          "timestamp": 1655766412952,
          "user": {
            "displayName": "Oleg Rybakov",
            "userId": "04792887722985073803"
          },
          "user_tz": 420
        },
        "id": "gZPXQGXAJacF",
        "outputId": "ae3458ef-d42a-47ca-9a05-0a8a98586d1a"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "TensorShape([1, 247, 1025])"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Shape is [batch, frames, frequency bins].\n",
        "spectrogram_magnitude.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D1iYyKIrJjuh"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "# Rows of the image are frames, columns are frequency bins.\n",
        "plt.imshow(spectrogram_magnitude[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eWXxX1wHnOI9"
      },
      "source": [
        "# Prepare TFLite models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s43vrab2nTOZ"
      },
      "outputs": [],
      "source": [
        "# Download TFLite modules\n",
        "# and place them in the current folder of the notebook.\n",
        "# -nc skips files which are already present, so that running the cell again\n",
        "# does not create duplicates such as non_stream_GAN.tflite.1\n",
        "!wget -nc https://storage.googleapis.com/gresearch/specinvert/non_stream_GAN.tflite\n",
        "!wget -nc https://storage.googleapis.com/gresearch/specinvert/stream_GAN_lookahead1.tflite\n",
        "!wget -nc https://storage.googleapis.com/gresearch/specinvert/stream_GAN_causal.tflite\n",
        "!wget -nc https://storage.googleapis.com/gresearch/specinvert/stream_GL.tflite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3fUOb64IO7sY"
      },
      "source": [
        "# Invert spectrogram with non streaming MelGAN"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Izc14W6AWPhr"
      },
      "outputs": [],
      "source": [
        "# The whole spectrogram is passed to the model in a single invocation.\n",
        "non_stream_tfl = RunNonStreaming(spectrogram_magnitude, \"non_stream_GAN.tflite\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zvwum8mBWfo2"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "plt.plot(non_stream_tfl[0])\n",
        "if use_colabtools:\n",
        "  colabtools.sound.PlaySound(non_stream_tfl[0], sample_rate)\n",
        "else:\n",
        "  # Only the last top level expression of a cell is displayed automatically,\n",
        "  # so the player created inside of this branch is displayed explicitly.\n",
        "  IPython.display.display(\n",
        "      IPython.display.Audio(non_stream_tfl[0], rate=sample_rate, autoplay=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PdBMnK8VWRE0"
      },
      "source": [
        "# Invert spectrogram with streaming MelGAN lookahead 1 hop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o7DTDf56WYuf"
      },
      "outputs": [],
      "source": [
        "# Mapping of input output indexes in TFLite:\n",
        "# key is the index of an input, value is the index of the matching output.\n",
        "inp_to_out_n = {0: 0}\n",
        "stream_lookahead_path_tfl_path = \"stream_GAN_lookahead1.tflite\"\n",
        "output_stream_lookahead_tfl = RunStreaming(\n",
        "    spectrogram_magnitude,\n",
        "    step=1,\n",
        "    tflite_model_path=stream_lookahead_path_tfl_path,\n",
        "    inp_to_out=inp_to_out_n,\n",
        "    input_index=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IybgZuLcWekU"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "plt.plot(output_stream_lookahead_tfl[0])\n",
        "if use_colabtools:\n",
        "  colabtools.sound.PlaySound(output_stream_lookahead_tfl[0], sample_rate)\n",
        "else:\n",
        "  # Only the last top level expression of a cell is displayed automatically,\n",
        "  # so the player created inside of this branch is displayed explicitly.\n",
        "  IPython.display.display(\n",
        "      IPython.display.Audio(output_stream_lookahead_tfl[0], rate=sample_rate, autoplay=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AqqlGK6cWrS7"
      },
      "source": [
        "# Invert spectrogram with streaming causal MelGAN (no lookahead)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZEr7AXEMWu2o"
      },
      "outputs": [],
      "source": [
        "# Mapping of input output indexes in TFLite:\n",
        "# key is the index of an input, value is the index of the matching output.\n",
        "inp_to_out_n = {0: 0}\n",
        "\n",
        "stream_causal_path_tfl_path = \"stream_GAN_causal.tflite\"\n",
        "output_stream_causal_tfl = RunStreaming(\n",
        "    spectrogram_magnitude,\n",
        "    step=1,\n",
        "    tflite_model_path=stream_causal_path_tfl_path,\n",
        "    inp_to_out=inp_to_out_n,\n",
        "    input_index=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oDCDhBmJWz_s"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "plt.plot(output_stream_causal_tfl[0])\n",
        "if use_colabtools:\n",
        "  colabtools.sound.PlaySound(output_stream_causal_tfl[0], sample_rate)\n",
        "else:\n",
        "  # Only the last top level expression of a cell is displayed automatically,\n",
        "  # so the player created inside of this branch is displayed explicitly.\n",
        "  IPython.display.display(\n",
        "      IPython.display.Audio(output_stream_causal_tfl[0], rate=sample_rate, autoplay=True))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pif8Qlw8W6Ta"
      },
      "source": [
        "# Invert spectrogram with streaming GL"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u3J0sHHMJxRN"
      },
      "outputs": [],
      "source": [
        "# Mapping of input output indexes in TFLite:\n",
        "# key is the index of an input, value is the index of the matching output.\n",
        "# Input 4 is passed below as the data input, the other inputs are states.\n",
        "inp_to_out = {0: 2, 1: 3, 2: 0, 3: 1, 4: 4, 5: 5}\n",
        "\n",
        "stream_gl_tfl_path = \"stream_GL.tflite\"\n",
        "output_stream_gl_tfl = RunStreaming(\n",
        "    spectrogram_magnitude,\n",
        "    step=1,\n",
        "    tflite_model_path=stream_gl_tfl_path,\n",
        "    inp_to_out=inp_to_out,\n",
        "    input_index=4)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pP96nnuXKGDR"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "plt.plot(output_stream_gl_tfl[0])\n",
        "if use_colabtools:\n",
        "  colabtools.sound.PlaySound(output_stream_gl_tfl[0], sample_rate)\n",
        "else:\n",
        "  # Only the last top level expression of a cell is displayed automatically,\n",
        "  # so the player created inside of this branch is displayed explicitly.\n",
        "  IPython.display.display(\n",
        "      IPython.display.Audio(output_stream_gl_tfl[0], rate=sample_rate, autoplay=True))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1z3afCgCXBgA"
      },
      "outputs": [],
      "source": [
        ""
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "demo.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
